import torch
import torch.nn as nn
import numpy as np
from deepset import *
from data_utils import *
from cdm import *
from collections import defaultdict
import itertools

num_players = 36

def str_to_idx(s):
    """Return the positions of the ``1`` entries in a whitespace-separated string.

    ``s`` is split on whitespace and each token parsed with ``int``; the
    result is a NumPy integer array of the positions whose value equals 1.
    """
    flags = np.array(list(map(int, s.split())))
    (positions,) = np.nonzero(flags == 1)
    return positions

def convert_str(lst):
    """Return a new list holding ``str(item)`` for every item of ``lst``."""
    return list(map(str, lst))

def score_fn(model, idxs):
    """Score the player subset ``idxs`` with ``model``.

    A length-``num_players`` 0/1 indicator vector (1 at every index in
    ``idxs``) is passed to ``model`` as a single float32 row of shape
    ``(1, num_players)``; element ``[0][0]`` of the detached output is
    returned as a NumPy scalar.
    """
    indicator = np.zeros(num_players)
    indicator[idxs] = 1
    batch = torch.tensor(indicator, dtype=torch.float32).unsqueeze(0)
    output = model(batch).detach().numpy()
    return output[0][0]

def print_team(model, sym=True, team_size=3):
    """Print the diagonal ``team_size x team_size`` blocks of a model's weight matrix.

    Reads ``model.weight.weight`` (assumed to be a square 2-D tensor — TODO
    confirm for each model class) and, when ``sym`` is true, prints blocks of
    the symmetrized matrix ``W + W.T`` instead of ``W``.

    Args:
        model: object exposing ``model.weight.weight`` as a torch tensor.
        sym: if true, symmetrize the matrix before printing.
        team_size: edge length of each printed diagonal block. The default
            of 3 matches the three-player teams built by ``np.array_split``
            in this script. The last block is smaller if the matrix size is
            not a multiple of ``team_size``.
    """
    mat = model.weight.weight.detach().numpy()
    if sym:
        mat = mat + mat.T
    # A distinct local name, so the module-level ``num_players`` is not shadowed.
    n_rows = mat.shape[0]
    for start in range(0, n_rows, team_size):
        stop = start + team_size
        print(mat[start:stop, start:stop])

# Load the observed lineups. Each entry of ``s.data`` is split into a
# whitespace-separated 0/1 indicator string (every token but the last) and a
# float score (the last token).
nm = "data/spread_data12.txt"
s = Single_DataSet(nm, num_players=num_players)
res = [
    (" ".join(tokens[:-1]), float(tokens[-1]))
    for tokens in (line.split() for line in s.data)
]
res.sort(key=lambda x: x[1])  # ascending by score
# Replace each indicator string with the array of selected player indices.
res = [(str_to_idx(indicator), val) for indicator, val in res]
all_idxs = range(num_players)
# Split players into num_players // 3 consecutive groups; with 36 players
# these are the triples [0, 1, 2], [3, 4, 5], ...
teams = np.array_split(all_idxs, num_players // 3)

##################################################################################

'''
Load models
'''

# Registry of trained models: name -> (constructor, checkpoint path). Only the
# entry selected by MODEL_NAME is constructed and loaded, so the other
# checkpoint files need not exist. get_cdm_weight (below) reads model.cs and
# model.ts, which presumably only the CDM model provides.
_MODEL_SPECS = {
    "fhoi": (lambda: FHoi_single(num_players), "example_models/spread_fhoi.pth"),
    "deepset": (lambda: DeepSet_single(num_players, embed_dim=20), "example_models/spread_deepset.pth"),
    "linear": (lambda: LR_single(num_players), "example_models/spread_linear.pth"),
    "cdm5": (lambda: CDM_single(num_players, 5), "example_models/spread_cdm5.pth"),
}
MODEL_NAME = "cdm5"

build_model, filepath = _MODEL_SPECS[MODEL_NAME]
model = build_model()
# map_location lets checkpoints saved from a GPU load on CPU-only machines;
# the model built above lives on the CPU, so nothing else changes.
model.load_state_dict(torch.load(filepath, map_location="cpu"))

def get_cdm_weight(model):
    """Return ``cs^T @ ts`` with its diagonal zeroed.

    ``cs`` and ``ts`` are the weight matrices of ``model.cs`` and
    ``model.ts``, converted to NumPy. The diagonal mask has size
    ``cs.shape[1]``, so the product must be square with that size. This
    implies both matrices have the same number of columns; the layout is
    assumed to be (k, num_players) — TODO confirm.
    """
    cs_w = model.cs.weight.detach().numpy()
    ts_w = model.ts.weight.detach().numpy()
    interaction = cs_w.T @ ts_w
    # Multiply by a 0/1 mask instead of overwriting the diagonal; the
    # arithmetic, including NaN/inf propagation, stays unchanged.
    mask = 1 - np.eye(cs_w.shape[1])
    return interaction * mask

# Save the zero-diagonal interaction matrix of the loaded model. Passing a
# path lets np.save open and close the file itself, so the data is flushed
# right away instead of when an unclosed handle is garbage-collected.
V = get_cdm_weight(model)
np.save("cdm_V.npy", V)

##################################################################################

'''
Compute optimal rest of the team given some team members
'''

# For every player, let the model pick the best two teammates drawn from the
# other teams. Then compare the observed score of that lineup with the best
# and the median observed lineups that contain this player and none of their
# own teammates. One CSV row is written per player.
out = []
pred_percent = []
baseline_percent = []
for player_no in range(num_players):
    # ``teams`` groups players into consecutive triples, so a player's team
    # index is player_no // 3.
    team = teams[player_no // 3]
    team_set = set(team)
    candidates = [p for p in range(num_players) if p not in team_set]

    # Exhaustive search over candidate pairs; on ties the first pair wins.
    mx = float("-inf")
    best_team = None
    for j, k in itertools.combinations(candidates, 2):
        val = score_fn(model, [player_no, j, k])
        if val > mx:
            mx = val
            best_team = [player_no, j, k]
    best_team.sort()

    # Observed lineups whose only member from this player's team is the player.
    ranks = [r for r in res if team_set.intersection(r[0]) == {player_no}]
    ranks.sort(key=lambda x: x[1])
    ranked_teams = {tuple(lineup): score for lineup, score in ranks}

    # A KeyError here means the model's pick never appears in the data.
    pred_val = ranked_teams[tuple(best_team)]
    actual_val = ranks[-1][1]
    if actual_val <= 0:
        # The ratios below only make sense against a positive best score.
        raise ValueError(
            f"best observed score for player {player_no} is not positive: {actual_val}"
        )
    med_val = ranks[len(ranks) // 2][1]

    # No abs(): a negative predicted score must lower the ratio, as it does
    # for the median baseline.
    pred_percent.append(pred_val / actual_val)
    baseline_percent.append(med_val / actual_val)

    out.append((
        str(player_no),
        " ".join(convert_str(best_team)),
        " ".join(convert_str(ranks[-1][0])),
        str(pred_val),
        str(actual_val),
        str(med_val),
    ))

# Columns: player, predicted lineup, best observed lineup, observed score of
# the predicted lineup, best observed score, median observed score.
with open("select_best_pair_cdm.csv", "w") as f:
    f.write("\n".join(",".join(row) for row in out))

#########################################################################################

# For every pair of teammates, let the model pick the best third member from
# the other teams. Then compare the observed score of that lineup with the
# best and the median observed lineups whose members from this team are
# exactly that pair. One CSV row is written per pair.
out = []
pred_percent = []
baseline_percent = []
for t in teams:
    team_set = set(t)
    # Players outside this team; the same candidates serve every pair in it.
    outsiders = [p for p in range(num_players) if p not in team_set]
    for duo in itertools.combinations(t, 2):
        i, j = duo

        # Exhaustive search over the third member; on ties the first wins.
        mx = float("-inf")
        best_team = None
        for k in outsiders:
            val = score_fn(model, [i, j, k])
            if val > mx:
                mx = val
                best_team = [i, j, k]
        best_team.sort()

        # Observed lineups whose members from this team are exactly {i, j}.
        ranks = [r for r in res if set(r[0]).intersection(team_set) == {i, j}]
        ranks.sort(key=lambda x: x[1])
        # ``lineup`` rather than ``t`` so the name of the team loop variable
        # is not reused.
        ranked_teams = {tuple(lineup): score for lineup, score in ranks}

        actual_val = ranks[-1][1]
        if actual_val <= 0:
            # The ratios below only make sense against a positive best score.
            raise ValueError(
                f"best observed score for pair {i},{j} is not positive: {actual_val}"
            )
        # A KeyError here means the model's pick never appears in the data.
        pred_val = ranked_teams[tuple(best_team)]
        med_val = ranks[len(ranks) // 2][1]

        pred_percent.append(pred_val / actual_val)
        baseline_percent.append(med_val / actual_val)

        out.append((
            " ".join(convert_str(duo)),
            " ".join(convert_str(best_team)),
            " ".join(convert_str(ranks[-1][0])),
            str(pred_val),
            str(actual_val),
            str(med_val),
        ))

# Columns: teammate pair, predicted lineup, best observed lineup, observed
# score of the predicted lineup, best observed score, median observed score.
with open("select_best_single_cdm.csv", "w") as f:
    f.write("\n".join(",".join(row) for row in out))

